#!/usr/bin/env python
# -*-coding=utf-8-*-



import os
import sys
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(base_dir)

import time
import datetime
import tensorflow as tf

from text_rnn.config import  TextCNNConfig
from text_rnn.model import TextRNN
from utils import train_utils
from utils import word2vec

# Command-line flag definitions. Every value below is read through FLAGS
# further down in this script.

# Model hyperparameters
tf.flags.DEFINE_integer("embedding_dim", 100, "Dimensionality of character embedding (default: 100)")
tf.flags.DEFINE_float("dropout_keep_prob", 0.5, "Dropout keep probability (default: 0.5)")
tf.flags.DEFINE_integer("num_hidden", 100, "the num of hidden ")
tf.flags.DEFINE_float("l2_lambda", 0.02, "L2 regularization lambda (default: 0.02)")
tf.flags.DEFINE_float("learning_rate", 0.01, "learning rate")

# Training parameters
tf.flags.DEFINE_string("train_data_path", "", "train data")
tf.flags.DEFINE_float("train_percent", 0.9, "the percent of train data")
tf.flags.DEFINE_integer("num_classes", 2, "the num of classes")

# Declared as a boolean flag: the value is used as a truth test when deciding
# whether to load pre-trained embeddings. As a string flag, passing
# `--pre_embedding=False` produced the non-empty (truthy) string "False", so
# the option could not be switched off that way.
tf.flags.DEFINE_boolean("pre_embedding", True, "whether to use pre embedding or not.")
tf.flags.DEFINE_string("embed_file", "", "embed file")

tf.flags.DEFINE_integer("batch_size", 256, "Batch Size (default: 256)")
tf.flags.DEFINE_integer("num_epochs", 200, "Number of training epochs (default: 200)")

tf.flags.DEFINE_integer("evaluate_every", 80, "Evaluate model on dev set after this many steps (default: 80)")
tf.flags.DEFINE_integer("checkpoint_every", 80, "Save model after this many steps (default: 80)")

tf.flags.DEFINE_integer("max_sent_length", 20, "max sentence length")
tf.flags.DEFINE_integer("decay_steps", 10,  "how many steps before decay learning rate")
tf.flags.DEFINE_float("decay_rate", 0.75, "Rate of decay for learning rate")

# Misc Parameters
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")


# Parse the command line immediately: the module-level code below reads flag
# values while the script is being executed.
FLAGS = tf.flags.FLAGS
FLAGS._parse_flags()

# Echo every flag, sorted by name, so each run records its own settings.
flag_values = FLAGS.__flags
print("\nParameters:")
for flag_name in sorted(flag_values):
    print("{}={}".format(flag_name.upper(), flag_values[flag_name]))
print("")


def create_config(vocab_sizes):
    """Build the model configuration from the parsed command-line flags.

    Args:
        vocab_sizes: Value stored on the config as ``vocab_sizes``; the
            caller in this script passes the vocabulary length.

    Returns:
        A ``TextCNNConfig`` instance whose ``embedding_dim``,
        ``sequence_length``, ``batch_size``, ``num_hidden``, ``num_classes``
        and ``vocab_sizes`` attributes have been overwritten.
    """
    # (attribute name, value) pairs, applied in this order.
    overrides = (
        ("embedding_dim", FLAGS.embedding_dim),
        ("sequence_length", FLAGS.max_sent_length),
        ("batch_size", FLAGS.batch_size),
        ("num_hidden", FLAGS.num_hidden),
        ("num_classes", FLAGS.num_classes),
        ("vocab_sizes", vocab_sizes),
    )
    config = TextCNNConfig()
    for attr_name, attr_value in overrides:
        setattr(config, attr_name, attr_value)
    return config


# Prepare the data: the helper hands back the train/dev inputs and labels
# together with the vocabulary processor used further down.
(x_train, y_train, x_dev, y_dev, vocab_processor) = train_utils.papre_train_data(
    train_data_path=FLAGS.train_data_path,
    max_sent_length=FLAGS.max_sent_length,
    num_classes=FLAGS.num_classes,
    train_percent=FLAGS.train_percent,
)

# The vocabulary length is both reported and passed to the model config.
vocab_sizes = len(vocab_processor.vocabulary_)
print("Vocabulary Size: {:d}".format(vocab_sizes))
print("Train/Dev split: {:d}/{:d}".format(len(y_train), len(y_dev)))

model_config = create_config(vocab_sizes)
print(model_config.__dict__)

# training
with tf.Graph().as_default():
    session_conf = tf.ConfigProto(
        allow_soft_placement=FLAGS.allow_soft_placement,
        log_device_placement=FLAGS.log_device_placement,
    )

    # Use the Session itself as the context manager: it is installed as the
    # default session for the body AND closed when the block is left (also on
    # an exception or SystemExit). `sess.as_default()` alone never closes it.
    with tf.Session(config=session_conf) as sess:
        model = TextRNN(model_config)

        # Every run writes into its own ./runs/<unix timestamp>/ directory.
        timestamp = str(int(time.time()))
        out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
        checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
        checkpoint_prefix = os.path.join(checkpoint_dir, "model")

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        # NOTE(review): tf.all_variables / tf.initialize_all_variables are the
        # legacy spellings of tf.global_variables /
        # tf.global_variables_initializer; kept because the TensorFlow version
        # this script targets is not visible here.
        saver = tf.train.Saver(tf.all_variables())

        # Persist the vocabulary next to the checkpoints.
        print("save vocab ... ")
        vocab_processor.save(os.path.join(out_dir, "vocab"))

        # Initialize all variables
        sess.run(tf.initialize_all_variables())

        # Overwrite the freshly initialised embedding variable with the
        # pre-trained vectors; this must happen after the initializer above,
        # otherwise the initializer would reset it.
        if FLAGS.pre_embedding:
            print("load pre word2vec ...")
            wv = word2vec.Word2vec()
            embed = wv.load_w2v_array(FLAGS.embed_file, vocab_processor)
            word_embedding = tf.constant(embed, dtype=tf.float32)
            t_assign_embedding = tf.assign(model.embedding, word_embedding)
            sess.run(t_assign_embedding)

        def train_step(x_batch, y_batch):
            """Run one training step on a batch and print its metrics.

            Feeds the batch with the dropout keep probability taken from
            ``FLAGS.dropout_keep_prob`` and runs ``model.train_op``.

            Returns:
                A ``(step, accuracy)`` tuple: the fetched value of
                ``model.global_step`` and the accuracy on this batch.
            """
            fetches = [model.train_op, model.global_step, model.loss_val, model.accuracy]
            feed = {
                model.input_x: x_batch,
                model.input_y: y_batch,
                model.dropout_keep_prob: FLAGS.dropout_keep_prob,
            }
            # The first fetch is the train op itself; its result is not needed.
            _, global_step, batch_loss, batch_acc = sess.run(fetches, feed)
            now = datetime.datetime.now().isoformat()
            print("{}: step {}, loss {:g}, acc {:g}".format(now, global_step, batch_loss, batch_acc))
            return global_step, batch_acc

        def dev_step(x_batch, y_batch, writer=None):
            """Evaluate the model on a dev batch and print its metrics.

            Dropout is disabled by feeding a keep probability of 1.0, and
            ``model.train_op`` is not run.

            Args:
                x_batch: Inputs fed to ``model.input_x``.
                y_batch: Labels fed to ``model.input_y``.
                writer: Unused; kept so existing callers remain valid.

            Returns:
                The accuracy fetched from ``model.accuracy``.
            """
            feed_dict = {
                model.input_x: x_batch,
                model.input_y: y_batch,
                model.dropout_keep_prob: 1.0
            }
            # Only fetch what is reported: logits and predictions were
            # previously fetched as well but never used.
            step, loss, accuracy = sess.run(
                [model.global_step, model.loss_val, model.accuracy],
                feed_dict)
            time_str = datetime.datetime.now().isoformat()
            print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))
            return accuracy

        # Main training loop. Each item produced by batch_iter is a sequence
        # of (x, y) pairs, which is unzipped into separate inputs and labels.
        batches = train_utils.batch_iter(list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs)

        for batch in batches:
            x_batch, y_batch = zip(*batch)
            current_step, _ = train_step(x_batch, y_batch)

            print(current_step)
            if current_step % FLAGS.evaluate_every == 0:
                print("\nEvaluation:")
                dev_acc = dev_step(x_dev, y_dev)
                print("")
                # Early stop: once dev accuracy exceeds 0.95, save a final
                # checkpoint and end the run.
                if dev_acc > 0.95:
                    path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    print("Saved model checkpoint to {}\n".format(path))
                    print("finish train")
                    # sys.exit instead of the bare exit(): the latter is a
                    # helper injected by the `site` module for interactive use
                    # and is missing when Python runs with -S.
                    sys.exit(0)

            if current_step % FLAGS.checkpoint_every == 0:
                path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                print("Saved model checkpoint to {}\n".format(path))

            # NOTE(review): this runs once per batch, not once per epoch, so
            # the model's epoch counter advances every step -- confirm that
            # this is intended.
            sess.run(model.epoch_increment)